import jsonlines
import re
from tqdm.contrib import tzip
import sys



def _read_jsonl(path):
    """Load every record of the JSON Lines file at *path* into a list."""
    # The context manager closes the underlying file handle as soon as the
    # records are materialised, instead of leaving it open until exit.
    with jsonlines.open(path, "r") as reader:
        return list(reader)


# The prediction file is the only command-line argument; fail with a usage
# hint rather than a bare IndexError when it is missing.
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} PREDICTION_FILE")
prediction_file = sys.argv[1]

# Model outputs to be scored; each record is read via its "R" and "input" keys.
data = _read_jsonl(prediction_file)
# Reference records in their unshuffled order; consumed in fixed-size groups
# by the "double accuracy" computation at the end of the script.
data2 = _read_jsonl("./data/ecare/ecare_no_shuffle.jsonl")
# Gold records, iterated in lockstep with `data`
# (presumably the same order as the prediction file -- verify).
labels = _read_jsonl("./data/ecare/ecare.jsonl")

# System-prompt texts, one per task. The evaluation loop detects which task an
# example belongs to by checking whether one of these strings occurs in the
# example's "input", so every character (including the original wording) must
# stay exactly as it was when the prompts were generated.
direct = (
    'You are given one question and two options, you should choose a more plausible option to answer the question properly. '
    'You answer format should be "Answer: (A|B)"'
)
user = (
    'You are given one question, two options, and a rule. '
    'You should refer to the rule and choose a more plausible option to answer the question properly.'
)
retrieve = (
    'You are given one question, two options, and some rules. '
    'You should refer to the rules and choose a more plausible option to answer the question properly.'
)
judge = (
    'You are given one question, two options, and one rule. '
    'You should judge whether the rule can help to answer this question. '
    'You answer format should be "Answer: (Yes|No)"'
)
# NOTE: this prompt text intentionally stops right after the opening `"Rule:`.
recall = (
    'You are given one question and two options. '
    'You should recall a rule first, then refer to the rule and choose a more plausible option to answer the question properly. '
    'You answer format should be "Rule:'
)

# Per-task tallies: "hit" counts correct predictions, "count" counts examples.
task_results = {task: {"hit": 0, "count": 0} for task in ["direct", "user", "retrieve", "judge", "recall"]}
# Per-rule tallies ({"total": ..., "correct": ...}), filled in by the scoring loop.
results = {}

# Matches a parenthesised choice, i.e. the literal text "(A)" or "(B)".
# The alternation is wrapped in a non-capturing group: without it the pattern
# `\(A|B\)` means "`(A` OR `B)`" and also fires on stray text such as "B)".
# A raw string avoids the invalid "\(" escape sequence warning.
_CHOICE_PATTERN = re.compile(r"\((?:A|B)\)")

# Answer-extraction regex per task; the same pattern is applied to both the
# model response and the gold output so the two extracted strings compare equal.
pattern_mapping = {
    "direct": _CHOICE_PATTERN,
    "user": _CHOICE_PATTERN,
    "retrieve": _CHOICE_PATTERN,
    "judge": re.compile(r"Yes|No"),
    "recall": _CHOICE_PATTERN,
}

indexes = [1, 2533, 5995, 13243, 2, 12476, 19448, 20259, 5, 1163, 3608, 13257, 21134, 13, 1909, 6198, 10109, 16163, 17761, 28, 2245, 5796, 7165, 15808, 17222, 30, 11854, 16931, 19714, 37, 3772, 6345, 6780, 8632, 13955, 16608, 44, 5788, 6612, 14323, 21309, 71, 4026, 6139, 15293, 81, 2465, 3787, 11545, 16182, 17215, 17653, 83, 15692, 18651, 18703, 91, 10194, 15171, 16860, 93, 4303, 5811, 13354, 100, 14339, 15086, 17311, 104, 1406, 7626, 10432, 16020, 106, 862, 12027, 20625, 118, 4419, 17987, 20932, 124, 6405, 10282, 11750, 14669, 16385, 144, 5074, 5631, 16509, 17295, 19227, 148, 3100, 4939, 15469, 150, 2187, 9889, 10396, 17235, 159, 2201, 3369, 13410, 161, 2258, 2838, 3805, 163, 4827, 11546, 13371, 167, 4503, 5525, 15149, 173, 2229, 11516, 17782, 182, 5336, 8705, 11208, 11224, 184, 1519, 2207, 13560, 188, 2607, 3320, 11574, 189, 5045, 7394, 17657, 208, 12112, 14202, 15490, 215, 1712, 1817, 6215, 20562, 216, 3793, 5819, 8535, 9856, 21038, 225, 2234, 9269, 10444, 15487, 20020, 235, 2831, 3666, 11322, 16503, 18890, 247, 9228, 11945, 12705, 256, 10278, 10746, 12497, 13396, 264, 1912, 3784, 9049, 15427, 276, 1567, 10600, 12283, 299, 805, 5739, 19933, 311, 1725, 16109, 18520, 339, 415, 5472, 12163, 347, 15633, 19809, 20150, 364, 7453, 7712, 11303, 365, 3099, 4086, 10705, 13408, 15537, 20679, 374, 4210, 8128, 12491, 16467, 386, 2783, 6411, 7686, 10752, 390, 8638, 13136, 14119, 20261, 417, 635, 4241, 8404, 8441, 17395, 18910, 424, 11066, 16423, 16528, 438, 6367, 9552, 16727, 445, 6253, 17727, 20571, 20953, 453, 2999, 15055, 17104, 18926, 454, 1505, 2221, 15014, 17094, 17789, 20343, 20778, 457, 5554, 11155, 13994, 466, 4414, 10966, 19593, 469, 7465, 7569, 11588, 20955, 471, 4154, 4255, 6649, 472, 2519, 4790, 9671, 491, 522, 15751, 17169, 498, 2119, 4808, 18506, 19586, 503, 3169, 7746, 20260, 510, 7428, 7944, 12676, 525, 8456, 9901, 13281, 528, 1524, 8896, 17139, 538, 5917, 8342, 19643, 20708, 546, 4062, 6729, 7597, 573, 1349, 10231, 12267, 577, 1502, 13164, 17254, 581, 3683, 
5432, 14210, 592, 618, 3057, 9532, 619, 2030, 2640, 12327, 623, 971, 3831, 7786, 630, 2275, 3137, 11216, 656, 1231, 2836, 20326, 657, 9067, 9499, 17863, 658, 5857, 6055, 13454, 660, 7260, 15981, 18349, 667, 12994, 16571, 19222, 694, 1257, 4142, 8956, 706, 1259, 3317, 7599, 707, 1346, 11029, 12255, 15465, 15829, 728, 4683, 5243, 9912, 10535, 12082, 20891, 742, 7863, 11234, 18218, 756, 1030, 1324, 7894, 764, 8409, 11372, 21177, 769, 12660, 15574, 16958, 770, 6272, 11554, 19907, 20372, 787, 2040, 3538, 14649, 793, 7068, 8166, 12787, 14540, 799, 8408, 13630, 17636, 831, 5659, 8060, 11090, 15817, 832, 5716, 6823, 9475, 14281, 854, 1473, 2531, 5903, 12685, 16436, 18638, 857, 9132, 17529, 20437, 858, 1475, 3710, 8111, 880, 5319, 9655, 12641, 19447, 881, 3976, 7669, 9391, 15015, 15108, 17220, 18237, 20609, 899, 1486, 16253, 16865, 914, 8013, 13745, 16328, 919, 1720, 5913, 9628, 12779, 17659, 923, 1980, 3679, 19011, 932, 5238, 13030, 15656, 15945, 944, 4915, 8993, 19548, 956, 1336, 20600, 21126, 972, 5515, 8596, 10045, 978, 3627, 10870, 17287, 993, 12823, 17730, 20068, 1013, 1356, 4423, 8945, 1014, 8070, 10303, 11278, 1017, 5007, 13543, 17711, 1019, 1360, 9414, 16629, 1029, 5328, 6094, 12284, 19349, 20984, 1038, 1454, 13409, 16806, 18049, 18630, 1042, 1908, 5455, 18068, 1066, 4555, 10557, 20309, 1068, 7337, 14182, 16439, 1088, 7001, 10180, 15270, 1091, 1537, 2015, 2955, 8357, 16337, 1093, 12278, 14486, 20269, 1138, 1620, 3377, 7474, 1142, 17514, 19351, 21214, 21300, 1148, 1887, 11456, 18512, 1151, 2529, 2977, 3817, 6634, 13574, 13721, 14611, 1156, 7568, 8578, 18447, 1182, 1351, 13945, 16154, 1191, 3808, 4770, 12157, 1202, 8592, 13273, 14916, 20838, 1208, 8957, 12589, 13575, 14686, 1215, 2449, 6589, 8619, 11987, 1230, 2585, 13869, 16847, 1241, 10863, 14351, 14904, 1246, 6166, 12560, 18509, 1247, 6125, 16453, 17440, 1254, 2807, 11093, 13303, 19053, 1262, 1271, 1554, 6208, 6266, 1278, 1772, 6007, 11194, 1299, 5924, 7023, 14053, 1307, 10636, 14191, 20096, 1327, 15516, 17132, 
20935, 1335, 4150, 16456, 17413, 17710, 1354, 4118, 5947, 11985, 1357, 4649, 16416, 18697, 1358, 11772, 13889, 16519, 1374, 2034, 3001, 11982, 1383, 2208, 7379, 9962, 1418, 4830, 6962, 15170, 19241, 1422, 6009, 17582, 20846, 1424, 3987, 11831, 18505, 1434, 5591, 6763, 7066, 1438, 10302, 12315, 14150, 1444, 1698, 2726, 6605, 15672, 1445, 5736, 8631, 14482, 16616, 1479, 2008, 18974, 21113, 1489, 2841, 3298, 16990, 19498, 1504, 2417, 2571, 10191, 10649, 12600, 13596, 17592, 1525, 4856, 13692, 17353, 1542, 2877, 4753, 8954, 11261, 11825, 1551, 6691, 12483, 14967, 1570, 4572, 11826, 13182, 13327, 1571, 14901, 15463, 16885, 1578, 6218, 10329, 20043, 20477, 1581, 2855, 15547, 17909, 1592, 1732, 7033, 8931, 1594, 2070, 4308, 15333, 20474, 20693, 1598, 3716, 9448, 16530, 1599, 9546, 9714, 14747, 1610, 8950, 14203, 15201, 1615, 9521, 11323, 12329, 1630, 1730, 4601, 13370, 19848, 1631, 8990, 10516, 19345, 20026, 1637, 5306, 11555, 21094, 1641, 3982, 7096, 11195, 1645, 4250, 12076, 15514, 16561, 18057, 19893, 1649, 7113, 7531, 20298, 1667, 5580, 7424, 14091, 1669, 5684, 11637, 16843, 1683, 4249, 10206, 12927, 15950, 1703, 3080, 4595, 14497, 16250, 1713, 4123, 4136, 9038, 1740, 7575, 16488, 17762, 1759, 2473, 10611, 12526, 13070, 1760, 7350, 7825, 18718, 19410, 1764, 4921, 10887, 13306, 13433, 17883, 1776, 5962, 8045, 10054, 11419, 12068, 1778, 5557, 10884, 18960, 1782, 3107, 3430, 12501, 13902, 1786, 4812, 6754, 10491, 11386, 1804, 2650, 5657, 11156, 13835, 19503, 1812, 11956, 15853, 18947, 1830, 4351, 8674, 18191, 1837, 1883, 9098, 19859, 1839, 2231, 2779, 9488, 18723, 1844, 8850, 18542, 19507, 1856, 7650, 8389, 9618, 1864, 2761, 7682, 11796, 1876, 6749, 13894, 15734, 1881, 7759, 10888, 19254, 1898, 9423, 12840, 12947, 15828, 17787, 19831, 1900, 8470, 8612, 15204, 1914, 6463, 8624, 10833, 11497, 19675, 1944, 5846, 17418, 21173, 1945, 12065, 12845, 13434, 20812, 1948, 3691, 5140, 20492, 1969, 12246, 13447, 18717, 1978, 6439, 14215, 17685, 1986, 2552, 3932, 8394, 2019, 11526, 
18845, 20865, 2024, 5842, 8305, 13981, 2033, 3970, 4391, 6502, 2066, 3270, 5178, 6164, 10298, 10977, 2068, 6370, 12766, 19529, 2069, 3018, 11694, 17903, 2084, 11928, 12703, 20594, 2092, 3611, 5366, 10642, 16830, 2108, 12514, 14014, 16367, 2122, 4360, 7237, 8623, 12841, 2140, 4304, 4860, 5407, 17600, 21253, 2168, 2287, 4121, 6925, 18735]
# Load the full rule file, closing the reader once its records are in memory.
with jsonlines.open("./data/ecare/all_full.jsonl", "r") as reader:
    rule_data = list(reader)

# Keep the explanation text of every row whose position appears in `indexes`.
# A set makes each membership test O(1) instead of scanning the whole list.
# NOTE: the result follows the row order of the file, not the order in which
# positions are listed in `indexes`, and a repeated position is kept once.
_selected_positions = set(indexes)
rules = [d['conceptual_explanation'] for i, d in enumerate(rule_data) if i in _selected_positions]

# Per-example correctness flags (1 = correct, 0 = wrong) for the direct and
# judge tasks, stored in the same order as the gold records in
# `directs` / `judges` so the two pairs of lists can be indexed together.
direct_predictions, judge_predictions = [], []
directs, judges = [], []

# NOTE: tzip stops at the shortest of the three sequences, so extra entries in
# the longer ones are silently left unscored.
for example, label_example, rule in tzip(data, labels, rules):
    # Keep only the text after the assistant marker; the marker is normalised
    # first because it may appear with a space before "assistant".
    # A response without the marker raises IndexError here.
    normalised = example["R"].replace("<|im_start|> assistant", "<|im_start|>assistant")
    response = normalised.split("<|im_start|>assistant")[1].strip()

    # Identify the task from the system prompt embedded in the input; anything
    # that matches none of the known prompts is treated as "recall".
    prompt = example["input"]
    if direct in prompt:
        task = "direct"
        directs.append(label_example)
    elif user in prompt:
        task = "user"
    elif retrieve in prompt:
        task = "retrieve"
    elif judge in prompt:
        task = "judge"
        judges.append(label_example)
    else:
        task = "recall"
    pattern = pattern_mapping[task]

    # First answer-looking match in the response; "None" marks "no answer
    # found" and can never equal a gold label. An explicit emptiness check
    # replaces the former bare `except:`, which also hid unrelated errors
    # (and KeyboardInterrupt).
    matches = pattern.findall(response)
    prediction = matches[0] if matches else "None"

    # The gold output is expected to always contain an answer; a record that
    # does not raises IndexError rather than being scored silently.
    label = pattern.findall(label_example["output"])[0]

    correct = prediction == label

    if task == 'direct':
        # Direct-task accuracy is additionally tracked per rule text.
        if rule not in results:
            results[rule] = {"total": 0, "correct": 0}
        results[rule]["total"] += 1
        if correct:
            results[rule]["correct"] += 1
        direct_predictions.append(1 if correct else 0)

    if task == 'judge':
        judge_predictions.append(1 if correct else 0)

    if correct:
        task_results[task]["hit"] += 1
    task_results[task]["count"] += 1


def _ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    # A task with no examples (or an empty `results`) would otherwise abort
    # the whole report with ZeroDivisionError.
    return numerator / denominator if denominator else float("nan")


# Sum of per-rule accuracies on the direct task. Every entry in `results` has
# total >= 1 because it is incremented right after the entry is created.
count = sum(results[r]["correct"] / results[r]["total"] for r in results)


for task, tally in task_results.items():
    print(f"Task: {task}")
    print(f"Accuracy: {_ratio(tally['hit'], tally['count'])}")
    print(f"Count: {tally['count']}")
    print(f"Hit: {tally['hit']}")

    print()

# Mean of the per-rule accuracies (each rule weighted equally, regardless of
# how many direct examples it has).
print(f"Task Accuracy: {_ratio(count, len(results))}")

# "Double accuracy": fraction of groups in which both the direct example and
# the judge example were answered correctly.
# Layout assumption (taken from the indexing below, TODO confirm against the
# data file): `data2` holds 1005 consecutive groups of 5 records, with the
# direct record at offset 0 and the judge record at offset 3 of each group.
group_total = 1005
group_stride = 5
judge_offset = 3

double_count = 0
for group in range(group_total):
    base = group * group_stride
    direct_record = data2[base]
    judge_record = data2[base + judge_offset]

    # Locate each record among the gold records seen by the scoring loop;
    # list.index raises ValueError when a record was never scored.
    direct_pos = directs.index(direct_record)
    judge_pos = judges.index(judge_record)

    direct_flag = direct_predictions[direct_pos]
    judge_flag = judge_predictions[judge_pos]
    if direct_flag == judge_flag and judge_flag == 1:
        double_count += 1

print(f"Double Accuracy: {double_count / group_total}")





